import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from linearmodels import PanelOLS
from matplotlib.ticker import FuncFormatter

from cat_rfs.code.utils.tools import save_stat
from cat_rfs.code.utils.table import Table


def _fig4_prepare_panel(df):
    """Clean the raw macro panel and add the variables used in Figure 4.

    Parameters
    ----------
    df : pandas.DataFrame
        Raw country-year data as read from ``macro_effects.csv``.

    Returns
    -------
    pandas.DataFrame
        A copy sorted by (``iso``, ``year``) with, for h = 0..3:
        ``dC{h}``, ``dGDP{h}``, ``dHpnom{h}`` -- 100 x log change of the level
        from year t-1 to year t+h within each country; ``r{h}`` -- ``eq_tr``
        led by h years; ``ret{h}`` -- 100 x log(1 + r0 + ... + rh).
        Also adds ``gdp_usd``, ``dmg_perc_gdp`` (damage relative to the
        previous year's ``gdp_usd``) and the disaster dummies ``I0`` (small)
        and ``I1`` (large). No year filter is applied here.
    """
    df = df.rename(columns={'rgdpbarro': 'rgdppc', 'rconsbarro': 'rconpc'})

    # Put GDP on a common scale: AUS is noted as reported in millions (the other
    # listed countries get the same /1000) and JPN in trillions (x1000).
    # Presumably the common unit is billions -- TODO confirm.
    in_millions = ['AUS', 'BEL', 'FIN', 'NLD', 'NOR', 'PRT', 'ESP', 'SWE', 'CHE']
    df.loc[df['iso'].isin(in_millions), 'gdp'] /= 1000
    df.loc[df['iso'] == 'JPN', 'gdp'] *= 1000

    df['gdp_usd'] = df['gdp'] / df['xrusd']
    df['total_damage'] = df['total_damage'].fillna(value=0)

    # groupby().shift() is positional within each country, so rows must be in
    # chronological order for the leads/lags below to line up with years.
    df = df.sort_values(['iso', 'year'])

    # Deflate house prices by CPI (the column keeps its 'hpnom' name).
    df['hpnom'] = df['hpnom'] / df['cpi']

    growth_levels = (('dC', 'rconpc'), ('dGDP', 'rgdppc'), ('dHpnom', 'hpnom'))
    for h in range(4):
        for prefix, level in growth_levels:
            by_iso = df.groupby('iso')[level]
            # Cumulative log change from t-1 to t+h, in percent.
            df[f'{prefix}{h}'] = np.log(by_iso.shift(-h) / by_iso.shift(1)) * 100
        df[f'r{h}'] = df.groupby('iso')['eq_tr'].shift(-h)

    # NOTE(review): the multi-year return adds simple returns (1 + r0 + ... + rh)
    # rather than compounding them; kept as-is so results are unchanged.
    gross = 1
    for h in range(4):
        gross = gross + df[f'r{h}']  # left-to-right, same rounding as 1 + r0 + r1 + ...
        df[f'ret{h}'] = np.log(gross) * 100

    lagged_gdp_usd = df.groupby('iso')['gdp_usd'].shift(1)
    df['dmg_perc_gdp'] = (df['total_damage'] / 1000) / (lagged_gdp_usd * 1000)

    # Damage/GDP bins (lower bound exclusive, upper inclusive): I0 = small, I1 = large.
    # Ratios outside both bins leave both dummies at 0.
    bins = [(0.002, 0.01), (0.01, 1)]
    for k, (lo, hi) in enumerate(bins):
        df[f'I{k}'] = 0
        df.loc[(df['dmg_perc_gdp'] > lo) & (df['dmg_perc_gdp'] <= hi), f'I{k}'] = 1

    return df


def _fig4_fit(sample, yvar, fixed_effects):
    """Fit ``yvar ~ 1 + I0 + I1`` by PanelOLS with time-clustered standard errors.

    With ``fixed_effects`` the formula also includes entity and time effects.
    """
    effects = ' + EntityEffects + TimeEffects' if fixed_effects else ''
    formula = f'{yvar} ~ 1 + I0 + I1{effects}'
    return PanelOLS.from_formula(formula, sample).fit(cov_type='clustered', cluster_time=True)


def _fig4_coefs(ols):
    """Return [coef I0, s.e. I0, coef I1, s.e. I1] for one table column (label-based lookup)."""
    return [ols.params['I0'], ols.std_errors['I0'], ols.params['I1'], ols.std_errors['I1']]


def _fig4_plot_irf(a, fits, coef, ylim, title, lw):
    """Plot the response of ``coef`` across horizons, with its confidence band, on axes ``a``.

    ``fits`` holds one fitted model per horizon (h = 0, 1, ...). The y-axis is
    limited to [-ylim, ylim]; tick labels divide by 100 and format as a
    percentage because the outcomes are already expressed in percent.
    """
    params, lb, ub = [], [], []
    for ols in fits:
        ci = ols.conf_int()
        params.append(ols.params[coef])
        lb.append(ci.loc[coef, 'lower'])
        ub.append(ci.loc[coef, 'upper'])

    horizons = list(range(len(fits)))
    pd.Series(params).plot(title=title, ax=a, color='k')
    a.axhline(y=0, linewidth=lw, color='k')
    a.set_xticks(horizons)
    a.set_ylim([-ylim, ylim])
    a.set_xlim([0, horizons[-1]])
    a.fill_between(horizons, ub, lb, facecolor='lightgray', linewidth=0.0)
    a.yaxis.set_major_formatter(FuncFormatter(lambda y, _: '{:.0%}'.format(y / 100)))


def print_fig4(output='console'):
    """
    Print Figure 4: Impulse-responses of selected macroeconomic variables to natural disasters occurring in year 0.

    For each outcome (consumption growth, GDP growth, house price growth,
    stock market return) and horizon h = 0..3, regress the outcome on the
    small (I0) and large (I1) disaster dummies with entity and time effects
    and time-clustered standard errors. Plot the I0 and I1 coefficients with
    their confidence bands. Also build the companion table (h = 0 and h = 2,
    with and without fixed effects).

    Parameters
    ----------
    output : str
        'paper' writes the table to cat_rfs/output/tables/tablea14.tex and the
        figure to cat_rfs/output/figures/fig4.eps. 'presentation' uses a
        compact 2x2 layout. Any other value calls ``Table.print_table()``
        without a path and shows the figure. Also forwarded to ``save_stat``.
    """
    df = _fig4_prepare_panel(pd.read_csv('cat_rfs/data/macro_effects.csv'))
    df = df[(df['year'] >= 1950) & (df['year'] <= 2017)]

    save_stat(df['I1'].sum(), 'n_large_events', formatting='%.0f', output=output)
    save_stat(len(df['iso'].unique()), 'n_countries', formatting='%.0f', output=output)

    df = df.set_index(['iso', 'year'])

    YVARS = {'Consumption growth': 'dC', 'GDP growth': 'dGDP', 'House price growth': 'dHpnom',
             'Stock market return': 'ret'}

    presentation = output == 'presentation'
    size = (4.0, 3.3) if presentation else (6.6, 8)
    lw = 0.25 if presentation else 0.5
    nrows = 2 if presentation else 4
    fig, ax = plt.subplots(nrows=nrows, ncols=2, figsize=size)
    if not presentation:
        for a, col in zip(ax[0], ['Small disaster', 'Large disaster']):
            a.set_title(col)
        for a, row in zip(ax[:, 0], YVARS):
            a.set_ylabel(row)

    T = Table(9, width='auto', justs=['l', 'd', 'd', 'd', 'd', 'd', 'd', 'd', 'd'],
              caption='Effect of natural disasters on selected macroeconomic variables',
              label='macro_effects')
    results = []
    other_stats = [['FE', '$N$', '$R^2$']]
    names = []
    for n, (name, stem) in enumerate(YVARS.items()):
        # One fixed-effects fit per horizon; it serves both the small (I0)
        # and the large (I1) response, so each regression runs only once.
        fits = []
        for h in range(4):
            yvar = f'{stem}{h}'
            sample = df[df[yvar].notnull()].copy()
            ols = _fig4_fit(sample, yvar, fixed_effects=True)
            fits.append(ols)
            if h in (0, 2):
                results.append(_fig4_coefs(ols))
                other_stats.append(['Yes', int(ols.nobs), ols.rsquared])
                pooled = _fig4_fit(sample, yvar, fixed_effects=False)
                results.append(_fig4_coefs(pooled))
                other_stats.append(['No', int(pooled.nobs), pooled.rsquared])

        # NOTE(review): in 'presentation' mode row/col depend only on n, so the
        # small and large responses are drawn on the same axes, both in black.
        ylim = df[f'{stem}0'].std() * 2
        row = n // 2 if presentation else n
        title = name if presentation else None
        small_col = n % 2 if presentation else 0
        large_col = n % 2 if presentation else 1
        _fig4_plot_irf(ax[row][small_col], fits, 'I0', ylim, title, lw)
        _fig4_plot_irf(ax[row][large_col], fits, 'I1', ylim, title, lw)

        names.append(name)
        if n in (1, 3):
            # Each table panel holds two outcomes.
            data = pd.DataFrame(results).T
            data.insert(0, 'desc', ['$Small_{i,t}$', '', '$Large_{i,t}$', ''])

            # NOTE(review): 'dof' is the row count of the last (h=3) sample of
            # the second outcome, not of the tabulated h=0/h=2 regressions;
            # kept to match the original output -- confirm intent.
            T.add_panel(data, formatting="%.3f",
                        supheaders=[''] + names, supheader_cols=[1, (2, 5), (6, 9)],
                        headers=['', '$h=0$', '$h=2$', '$h=0$', '$h=2$'],
                        header_cols=[1, (2, 3), (4, 5), (6, 7), (8, 9)],
                        regression_specs={'stars': (0.1, 0.05, 0.01), 'dof': sample.shape[0]})

            T.add_panel(other_stats, formatting="%.3f")

            results = []
            other_stats = [['FE', '$N$', '$R^2$']]
            names = []

    if output == 'paper':
        T.print_table('cat_rfs/output/tables/tablea14.tex')
    else:
        T.print_table()

    # Lay out once, after every panel has been drawn.
    fig.tight_layout()
    fig.subplots_adjust(right=0.99)
    if output == 'paper':
        fig.savefig('cat_rfs/output/figures/fig4.eps')
        plt.close(fig)
    else:
        plt.show()
